ArviZ in depth: plot_trace
Who said looking at traces wasn't fun?
Introduction
plot_trace is one of the most common plots to assess the convergence of MCMC runs; therefore, it is also one of the most used ArviZ functions. plot_trace has many parameters that allow creating highly customizable plots, but they may not be straightforward to use. Several reasons explain the convoluted arguments and their format, and there is no clear culprit: ArviZ has to integrate with several libraries such as xarray and matplotlib, which provide amazing features and customization power, and we'd like to allow ArviZ users to access all these features. However, we also aim to keep ArviZ usage simple and with sensible defaults; plot_xyz(idata) should generate acceptable results in most situations.

This post aims to be an extension to the API section on plot_trace, focusing mostly on arguments where examples may be lacking and arguments that appear often in questions posted to ArviZ issues. Therefore, the most common arguments such as var_names will not be covered, and for arguments that I do not remember appearing in issues or generating confusion, only some examples will be shown without an in-depth description.

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
# html render is not correctly rendered in blog,
# comment the line below if in jupyter
xr.set_options(display_style="text")
rng = np.random.default_rng()
az.style.use("arviz-darkgrid")
idata_centered = az.load_arviz_data("centered_eight")
idata = az.load_arviz_data("rugby")
The kind argument
az.plot_trace generates two columns. The left one calls plot_dist to plot the KDE or histogram of the data, and the right column can contain either the trace itself (which gives the plot its name) or a rank plot, for which two visualizations are available. Rank plots are an alternative to trace plots; see https://arxiv.org/abs/1903.08008 for more details.

fig, axes = plt.subplots(3, 2, figsize=(12, 6))
for i, kind in enumerate(("trace", "rank_bars", "rank_vlines")):
    az.plot_trace(idata, var_names="home", kind=kind, ax=axes[i, :]);
fig.tight_layout()
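To get an intuition for what the rank visualizations show, here is a minimal sketch of the ranking step on made-up numbers: every draw is ranked among the draws of all chains pooled together, and the ranks are then split back per chain. The pooled_ranks helper and the toy values are hypothetical and ignore ties; ArviZ's own implementation may differ in the details.

```python
# Hypothetical toy draws: 2 chains with 4 draws each (not taken from the rugby data).
chains = [
    [0.1, 0.4, 0.35, 0.8],
    [0.2, 0.9, 0.05, 0.5],
]


def pooled_ranks(chains):
    """Rank every draw among the draws of all chains, then split back per chain."""
    # Keep track of which chain and draw each value came from.
    pooled = [
        (value, c, d)
        for c, chain in enumerate(chains)
        for d, value in enumerate(chain)
    ]
    ranks = [[0] * len(chain) for chain in chains]
    # Sorting by value gives the pooled order; rank 1 is the smallest draw.
    for rank, (_, c, d) in enumerate(sorted(pooled), start=1):
        ranks[c][d] = rank
    return ranks


ranks = pooled_ranks(chains)
for c, chain_ranks in enumerate(ranks):
    print(f"chain {c}: ranks {chain_ranks}")
```

If the chains are mixing well, the ranks within each chain should look roughly uniform, which is what the rank plots let us check visually.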
az.plot_trace(idata_centered, var_names="tau");
az.plot_trace(idata_centered, var_names="tau", divergences=None);
ax = az.plot_trace(idata, var_names="home", rug=True, rug_kwargs={"alpha": .4})
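But what about having both rug and divergences at the same time? Fear not, ArviZ automatically modifies the default for divergences from bottom to top to prevent rug and divergences from overlapping:

az.plot_trace(idata_centered, var_names="mu", rug=True);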
lines : list of tuple of (str, dict, array_like), optional
List of (var_name, {'coord': selection}, [line, positions]) to be overplotted as vertical lines on the density and horizontal lines on the trace.

It is possible that the first thought after reading this line is similar to "What is with this weird format?" Well, this format is actually the standard way ArviZ uses to iterate over xarray.Dataset objects because it contains all the info about the variable and the selected coordinates as well as the values themselves. The main helper function that handles this is arviz.plots.plot_utils.xarray_var_iter.

This section will be a little different from the other ones, and will focus on boosting plot_trace capabilities with internal ArviZ functions. You may want to skip the section altogether or go straight to the end.

Let's see what xarray_var_iter does with a simple dataset. We will create a dataset with two variables: a will be a 2x3 matrix and b will be a scalar. In addition, the dimensions of a will be labeled.

ds = xr.Dataset({
    "a": (("pos", "direction"), rng.normal(size=(2, 3))),
    "b": 12,
    "pos": ["top", "bottom"],
    "direction": ["x", "y", "z"]
})
ds
from arviz.plots.plot_utils import xarray_var_iter
for var_name, sel, values in xarray_var_iter(ds):
    print(var_name, sel, values)

xarray_var_iter has iterated over every single scalar value without losing track of where each value came from. We can also modify the behaviour to skip some dimensions (e.g. in ArviZ we generally iterate over data dimensions and skip the chain and draw dims).

for var_name, sel, values in xarray_var_iter(ds, skip_dims={"direction"}):
    print(var_name, sel, values)
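For readers who prefer to see the idea without xarray, the iteration pattern can be sketched in plain Python. Everything below (the data dict, var_iter and select) is a hypothetical stand-in written for this post, not ArviZ code; the real helper operates on xarray objects and what exactly it yields may differ between ArviZ versions.

```python
from itertools import product

# Toy stand-in for `ds`: the numbers are made up, only the structure mirrors it.
data = {
    "a": {
        "dims": ("pos", "direction"),
        "coords": {"pos": ["top", "bottom"], "direction": ["x", "y", "z"]},
        "values": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
    },
    "b": {"dims": (), "coords": {}, "values": 12},
}


def select(values, dims, coords, sel):
    """Index nested lists along the dims present in sel, keeping the others whole."""
    if not dims:
        return values
    dim, rest = dims[0], dims[1:]
    if dim in sel:
        position = coords[dim].index(sel[dim])
        return select(values[position], rest, coords, sel)
    return [select(v, rest, coords, sel) for v in values]


def var_iter(data, skip_dims=frozenset()):
    """Yield (var_name, selection, values) for every coordinate combination."""
    for name, var in data.items():
        loop_dims = [d for d in var["dims"] if d not in skip_dims]
        # product() over no dims yields a single empty selection (scalar case).
        for labels in product(*(var["coords"][d] for d in loop_dims)):
            sel = dict(zip(loop_dims, labels))
            yield name, sel, select(var["values"], var["dims"], var["coords"], sel)


for var_name, sel, values in var_iter(data, skip_dims={"direction"}):
    print(var_name, sel, values)
```

Skipping a dimension simply means it is not looped over, so the yielded values keep that dimension instead of being reduced to scalars.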
Now that we know about xarray_var_iter and what it does, we can use it to generate a list in the required format directly from xarray objects. Let's say, for example, we were interested in plotting the mean as a line in the trace plot:

var_names = ["home", "atts"]
lines = list(xarray_var_iter(idata.posterior[var_names].mean(dim=("chain", "draw"))))
az.plot_trace(idata, var_names=var_names, lines=lines);
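And what about quantile lines? Let's plot the 10% and 90% quantile lines, but only for the defs variable:

var_names = ["home", "defs"]
quantile_ds = idata.posterior[["defs"]].quantile((.1, .9), dim=("chain", "draw"))
lines = list(xarray_var_iter(quantile_ds, skip_dims={"quantile"}))
az.plot_trace(idata, var_names=var_names, lines=lines);

Note: This same approach can also be used with az.hdi, skipping the hdi dimension.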
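Since lines is just a list of tuples, it can also be written by hand when only a few reference values are needed. The sketch below only builds and inspects such a list; the numeric values are hypothetical and chosen purely to illustrate the format described in the docstring.

```python
# Hypothetical reference values, chosen only to illustrate the format.
lines = [
    # scalar variable: empty selection, a single line position
    ("home", {}, [0.25]),
    # variable with a "team" dimension: one entry per selected coordinate,
    # here with two positions each (e.g. a lower and an upper bound)
    ("defs", {"team": "Wales"}, [-0.3, 0.1]),
    ("defs", {"team": "Italy"}, [0.2, 0.6]),
]

for var_name, selection, positions in lines:
    print(var_name, selection, positions)

# The list could then be passed on with az.plot_trace(idata, lines=lines).
```

Writing the tuples manually gets tedious quickly, which is why generating them from xarray objects is convenient for anything beyond a couple of lines.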
Aggregation kwargs
This section is dedicated to 5 different kwargs, closely related to each other: compact+compact_prop, combined+chain_prop and legend. If we focus on the distribution plots of the left column, we may want to aggregate data along 2 possible dimensions: chains or variable dimension(s) (the school dimension in centered_eight data, the team dimension in rugby data...). As aggregating or not along these 2 possible dimensions is independent, we end up with 4 possibilities.

In az.plot_trace, the argument combined governs the aggregation of all chains into a single plot (it has no effect on the trace, only on the distributions), and compact governs the aggregation of the variable dimension(s). To be able to distinguish each single line after some aggregation has taken place, a legend argument is also available to show the legend with the data labels. chain_prop and compact_prop allow customization of the aesthetics mapping.

We'll now cover all 4 possibilities to showcase all supported cases and explore related customizations.

combined=False and compact=False
az.plot_trace(idata, var_names=["home", "defs"], combined=False, compact=False, legend=True);
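As a quick mental model of the four combinations, we can count how many distribution subplots and how many lines per subplot to expect for the defs variable. This is only a counting sketch with a hypothetical helper, distribution_layout, using the fact that the rugby data has 4 chains and 6 teams; it does not call ArviZ.

```python
from itertools import product

n_chains, n_teams = 4, 6  # rugby data: 4 chains, 6 teams


def distribution_layout(combined, compact):
    """Return (number of subplots, lines per subplot) in the distribution column."""
    # compact=True puts all teams in the same subplot
    n_subplots = 1 if compact else n_teams
    # combined=True merges all chains into a single line per team
    lines_per_team = 1 if combined else n_chains
    teams_per_subplot = n_teams if compact else 1
    return n_subplots, lines_per_team * teams_per_subplot


layouts = {
    (combined, compact): distribution_layout(combined, compact)
    for combined, compact in product((False, True), repeat=2)
}
for (combined, compact), (n_subplots, n_lines) in layouts.items():
    print(f"combined={combined}, compact={compact}: "
          f"{n_subplots} subplot(s), {n_lines} line(s) each")
```

The case with 24 lines in a single subplot is the one where customizing the aesthetics mapping and adding a legend helps the most.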
combined=True and compact=False
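Chains are aggregated into a single quantity if possible. Therefore, the distribution column will have one line per subplot due to the aggregation, but the trace column will be the same as in the previous section.

This is also a chain-only setting; the default mapping is to use color to distinguish chains. However, we'll use this example to show usage of chain_prop to map the chain to the linewidth:

chain_prop = {"linewidth": (.5, 1, 2, 3)}
az.plot_trace(
    idata, var_names=["home", "defs"], combined=True, chain_prop=chain_prop, compact=False, legend=True
);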
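combined=False and compact=True
az.plot_trace(idata, var_names=["home", "defs"], combined=False, compact=True, legend=True);

The first two things that jump out are that ArviZ has drastically modified the default aesthetics of the plot and that the plot now fits comfortably on a single screen; bye bye scrolling.

We can also see that legend=True has added multiple legends to the figure. The chain legend is always included in the top right trace plot, and the plots in the distribution column contain a legend if necessary.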
combined=True and compact=True
To reduce even more the clutter of lines in the trace plot, we can also combine chains. Moreover, the linestyle -> chain mapping can be distracting, especially if we don't care too much about distinguishing the chains from one another. Like we did before, we will use chain_prop to control this.

az.plot_trace(idata, var_names=["home", "defs"], combined=True, chain_prop={"ls": "-"}, compact=True);
Finally, we will explore alternative usage options for chain_prop and compact_prop. In the two previous examples we passed a dictionary mapping the property name to the values to use. Another alternative is to pass a string present in plt.rcParams["axes.prop_cycle"], which in our case is color only.

az.plot_trace(
    idata, var_names=["home", "defs"],
    combined=True, chain_prop="color",
    compact=True, compact_prop={"lw": np.linspace(.5, 3, 6)}
);
Summing it all up
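Now that we have covered most arguments, let's put everything into practice. Try to generate a trace plot following the instructions below:

- Plot the variables home, defs and atts, showing only the Scotland, Ireland, Italy and Wales coordinates.
- For the defs variable, plot lines showing the 70% HDI.
- Use the following colors: 'C0', 'C1', 'C2', 'C3', "xkcd:purple blue"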
#collapse-hide
coords = {"team": ["Scotland", "Ireland", "Italy", "Wales"]}
quantile_ds = az.hdi(idata, var_names="defs", coords=coords, hdi_prob=.7)
lines = list(xarray_var_iter(quantile_ds, skip_dims={"hdi", "team"}))
chain_prop = {"color": ['C0', 'C1', 'C2', 'C3', "xkcd:purple blue"]}
az.plot_trace(
    idata, var_names=["home", "defs", "atts"],
    combined=True, chain_prop=chain_prop,
    compact=True, compact_prop={"lw": np.linspace(.5, 3, 6), "ls": ("-", "--")},
    lines=lines,
    coords=coords
);
Warning: Post still in progress!

Comments are not enabled for the blog; to inquire further about the contents of the post, ask on ArviZ Issues or PyMC Discourse.